import math
import torch
import torch.nn as nn
try: 
    from chrono_initialization import init_gru as chrono_init
except Exception: 
    chrono_init = None  # If chrono_initialization.py is not available


class LayerNormGRUCell(torch.nn.Module):
    """GRU-style recurrent cell whose linear outputs are layer-normalised.

    Four linear maps are used: ``i2h`` and ``h2h`` each produce
    ``2 * hidden_size`` gate pre-activations, while ``h_hat_W`` and
    ``h_hat_U`` produce the two halves of the candidate state.  Every linear
    output passes through its own ``LayerNorm`` created with
    ``elementwise_affine=False`` (no learnable scale or shift).

    Args:
        input_size: Feature count of the (flattened) input ``x``.
        hidden_size: Feature count of the hidden state.
        bias: Whether the four linear maps carry a bias term.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        gate_features = 2 * hidden_size

        # Normalisers for the gate pre-activations and the candidate terms.
        self.ln_i2h = nn.LayerNorm(gate_features, elementwise_affine=False)
        self.ln_h2h = nn.LayerNorm(gate_features, elementwise_affine=False)
        self.ln_cell_1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.ln_cell_2 = nn.LayerNorm(hidden_size, elementwise_affine=False)

        # Input-to-hidden and hidden-to-hidden projections.
        self.i2h = nn.Linear(input_size, gate_features, bias=bias)
        self.h2h = nn.Linear(hidden_size, gate_features, bias=bias)
        self.h_hat_W = nn.Linear(input_size, hidden_size, bias=bias)
        self.h_hat_U = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.hidden_size = hidden_size
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-bound, bound)

    def forward(self, x, h):
        """Advance the cell by one step.

        Args:
            x: Input for this step; flattened to ``(batch, -1)``.
            h: Previous hidden state; flattened to ``(batch, -1)``.

        Returns:
            The new hidden state as a 2-D ``(batch, hidden_size)`` tensor.
        """
        prev_state = h.view(h.size(0), -1)
        step_input = x.view(x.size(0), -1)

        # Both gates come from the sum of the two normalised projections.
        gate_preact = (
            self.ln_i2h(self.i2h(step_input)) + self.ln_h2h(self.h2h(prev_state))
        )
        # First half of the columns is z_t, second half is r_t.
        z_t, r_t = torch.sigmoid(gate_preact).chunk(2, dim=1)

        # Candidate state: r_t scales only the hidden-state contribution.
        from_input = self.ln_cell_1(self.h_hat_W(step_input))
        from_state = self.ln_cell_2(self.h_hat_U(prev_state))
        candidate = torch.tanh(from_input + r_t * from_state)

        # z_t blends the previous state with the candidate.
        new_state = (1.0 - z_t) * prev_state + z_t * candidate
        return new_state.view(new_state.size(0), -1)


class GRUModel(nn.Module):
    """Stack of ``LayerNormGRUCell`` layers unrolled over time with a scalar head.

    The first cell consumes ``input_size`` features; every later cell consumes
    the ``hidden_size`` output of the cell below it.  The top layer's hidden
    state at each step goes through dropout and a ``Linear(hidden_size, 1)``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden-state width of every cell.
        num_layers: Number of stacked cells.
        dropout: Dropout probability applied to the stacked top-layer outputs.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.cells = nn.ModuleList([
            LayerNormGRUCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):  # x: (B, T, F)
        """Run the stacked cells over every time step of ``x``.

        Args:
            x: Tensor of shape ``(B, T, F)``.

        Returns:
            Tensor of shape ``(B, T)`` with one scalar prediction per step.
        """
        B, T, _ = x.shape
        # The initial state must match the input's dtype as well as its
        # device; otherwise a model cast to float64/half fails inside the
        # cells with a dtype mismatch.
        h = [
            torch.zeros(B, self.hidden_size, device=x.device, dtype=x.dtype)
            for _ in range(self.num_layers)
        ]

        outputs = []
        for t in range(T):
            input_t = x[:, t, :]
            for layer in range(self.num_layers):
                h[layer] = self.cells[layer](input_t, h[layer])
                input_t = h[layer]  # feeds the next layer up
            outputs.append(h[-1])  # last layer output at time t

        out = torch.stack(outputs, dim=1)  # (B, T, H)
        out = self.dropout(out)
        out = self.fc(out)                 # (B, T, 1)
        return out.squeeze(-1)             # (B, T)

class BasicGRUModel(nn.Module):
    """Baseline model: ``nn.GRU`` followed by dropout and a scalar linear head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden-state width of the GRU.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability.  Always applied to the GRU output;
            passed to ``nn.GRU`` as inter-layer dropout only when
            ``num_layers > 1``.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        # nn.GRU applies its dropout only between stacked layers.  With a
        # single layer it has no effect and emits a UserWarning, so pass 0.0.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):  # x: (B, T, F)
        """Return one scalar prediction per time step, shape ``(B, T)``."""
        out, _ = self.gru(x)        # (B, T, H)
        out = self.dropout(out)
        out = self.fc(out)          # (B, T, 1)
        return out.squeeze(-1)      # (B, T)


def apply_chrono_init(gru_module: nn.GRU, Tmax: int, Tmin: float = 1.0, debug: bool = False) -> nn.GRU:
    """Chrono-initialise an ``nn.GRU`` through the optional ``chrono_init`` helper.

    Args:
        gru_module: GRU module handed to ``chrono_init``.
        Tmax: Forwarded to ``chrono_init`` as ``Tmax``; must not be ``None``.
        Tmin: Forwarded to ``chrono_init`` as ``Tmin``.
        debug: When true, print the first five update/reset bias entries of
            the first parameter whose name contains ``"bias_ih"``.

    Returns:
        Whatever ``chrono_init`` returns for ``gru_module``.

    Raises:
        ImportError: If the ``chrono_init`` helper could not be imported.
        ValueError: If ``Tmax`` is ``None``.
    """
    if chrono_init is None:
        raise ImportError("chrono_initialization.py not found or init_gru import failed.")
    if Tmax is None:
        raise ValueError("Tmax must be provided for chrono initialization")

    initialised = chrono_init(gru_module, Tmax=Tmax, Tmin=Tmin)
    if not debug:
        return initialised

    with torch.no_grad():
        # Only the first matching bias is reported.
        first_bias = next(
            (p for pname, p in initialised.named_parameters() if "bias_ih" in pname),
            None,
        )
        if first_bias is not None:
            # The bias is read as three equal chunks: the first is treated as
            # the reset gate, the second as the update gate.
            chunk = first_bias.shape[0] // 3
            reset_part = first_bias[0:chunk]
            update_part = first_bias[chunk:2 * chunk]
            print("[Chrono] Update gate bias (first 5):", update_part[:5].cpu().numpy())
            print("[Chrono] Reset gate bias (first 5):", reset_part[:5].cpu().numpy())

    return initialised

class LayerNormGRUCellWithChrono(torch.nn.Module):
    """Layer-normalised GRU-style cell with an extra post-normalisation gate bias.

    The learnable ``gate_bias`` vector (``2 * hidden_size`` entries) is added
    to the gate pre-activations *after* layer normalisation, so a chrono-style
    initialisation of the gates is not removed by the normalisers.  The first
    ``hidden_size`` entries act on ``z_t``, the remaining ones on ``r_t``.

    Args:
        input_size: Feature count of the (flattened) input ``x``.
        hidden_size: Feature count of the hidden state.
        bias: Whether the four linear maps carry a bias term.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        gate_features = 2 * hidden_size

        # Parameter-free normalisers (no learnable scale or shift).
        self.ln_i2h = nn.LayerNorm(gate_features, elementwise_affine=False)
        self.ln_h2h = nn.LayerNorm(gate_features, elementwise_affine=False)
        self.ln_cell_1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.ln_cell_2 = nn.LayerNorm(hidden_size, elementwise_affine=False)

        self.i2h = nn.Linear(input_size, gate_features, bias=bias)
        self.h2h = nn.Linear(hidden_size, gate_features, bias=bias)
        self.h_hat_W = nn.Linear(input_size, hidden_size, bias=bias)
        self.h_hat_U = nn.Linear(hidden_size, hidden_size, bias=bias)

        self.hidden_size = hidden_size

        # Post-LN bias for the gates; this is what chrono init writes to.
        self.gate_bias = nn.Parameter(torch.zeros(gate_features))

        self.reset_parameters()

    def reset_parameters(self):
        """Uniformly re-draw all parameters, then reset ``gate_bias`` to zero."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        with torch.no_grad():
            for param in self.parameters():
                param.uniform_(-bound, bound)
            # The chrono bias stays neutral until explicitly initialised.
            self.gate_bias.zero_()

    def forward(self, x, h):
        """Advance the cell by one step.

        Args:
            x: Input for this step; flattened to ``(batch, -1)``.
            h: Previous hidden state; flattened to ``(batch, -1)``.

        Returns:
            The new hidden state as a 2-D ``(batch, hidden_size)`` tensor.
        """
        prev_state = h.view(h.size(0), -1)
        step_input = x.view(x.size(0), -1)

        normed_input = self.ln_i2h(self.i2h(step_input))
        normed_state = self.ln_h2h(self.h2h(prev_state))

        # The bias is added after normalisation so it survives it.
        gate_preact = normed_input + normed_state + self.gate_bias
        z_t, r_t = torch.sigmoid(gate_preact).chunk(2, dim=1)

        # Candidate state: r_t scales only the hidden-state contribution.
        from_input = self.ln_cell_1(self.h_hat_W(step_input))
        from_state = self.ln_cell_2(self.h_hat_U(prev_state))
        candidate = torch.tanh(from_input + r_t * from_state)

        new_state = (1.0 - z_t) * prev_state + z_t * candidate
        return new_state.view(new_state.size(0), -1)

class GRUModelLayerNormChrono(nn.Module):
    """Stack of ``LayerNormGRUCellWithChrono`` layers with a scalar head per step.

    The first cell consumes ``input_size`` features; every later cell consumes
    the ``hidden_size`` output of the cell below it.  The top layer's hidden
    state at each step goes through dropout and a ``Linear(hidden_size, 1)``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden-state width of every cell.
        num_layers: Number of stacked cells.
        dropout: Dropout probability applied to the stacked top-layer outputs.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.cells = nn.ModuleList([
            LayerNormGRUCellWithChrono(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Run the stacked cells over every time step of ``x``.

        Args:
            x: Tensor of shape ``(B, T, F)``.

        Returns:
            Tensor of shape ``(B, T)`` with one scalar prediction per step.
        """
        B, T, _ = x.shape
        # The initial state must match the input's dtype as well as its
        # device; otherwise a model cast to float64/half fails inside the
        # cells with a dtype mismatch.
        h = [
            torch.zeros(B, self.hidden_size, device=x.device, dtype=x.dtype)
            for _ in range(self.num_layers)
        ]
        outputs = []
        for t in range(T):
            input_t = x[:, t, :]
            for layer in range(self.num_layers):
                h[layer] = self.cells[layer](input_t, h[layer])
                input_t = h[layer]  # feeds the next layer up
            outputs.append(h[-1])  # top-layer state at time t
        out = torch.stack(outputs, dim=1)  # (B, T, H)
        out = self.dropout(out)
        out = self.fc(out)                 # (B, T, 1)
        return out.squeeze(-1)             # (B, T)
        
def apply_chrono_init_layernorm_chrono_model(model: nn.Module, Tmax: int, Tmin: float = 1.0, debug: bool = False) -> nn.Module:
    """
    Applies chrono init to LayerNormGRUCellWithChrono (post-LN gate_bias).

    For every cell in ``model.cells`` the z-gate half of ``gate_bias`` is set
    to ``-log(u)`` with ``u`` drawn uniformly between ``Tmin`` and ``Tmax``,
    and the r-gate half is set to zero.  The biases are modified in place.

    Args:
        model: Module exposing an iterable ``cells``; each cell must provide
            ``hidden_size`` and a ``gate_bias`` of ``2 * hidden_size`` entries.
        Tmax: Upper end of the sampling range; must not be ``None``.
        Tmin: Lower end of the sampling range; must be strictly positive.
        debug: When true, print the first (up to) five z- and r-bias entries
            of every cell.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If ``Tmax`` is ``None``, ``Tmin`` is not positive, or
            ``Tmax`` is smaller than ``Tmin``.
    """
    if Tmax is None:
        raise ValueError("Tmax must be provided")
    # log(u) is undefined for u <= 0; without this check a bad range would
    # silently write NaN/inf biases into the model.
    if Tmin <= 0 or Tmax < Tmin:
        raise ValueError(
            f"Chrono init requires 0 < Tmin <= Tmax, got Tmin={Tmin}, Tmax={Tmax}"
        )

    for cell in model.cells:
        H = cell.hidden_size
        device = cell.gate_bias.device

        u = (Tmax - Tmin) * torch.rand(H, device=device) + Tmin
        # Negative log gives non-positive biases, i.e. small initial z_t.
        bz = -torch.log(u)

        with torch.no_grad():
            cell.gate_bias[:H].copy_(bz)   # z-gate
            cell.gate_bias[H:].zero_()     # r-gate

        if debug:
            # Slice the z half first so r-gate entries are never printed
            # under the z label when H < 5.
            z_bias = cell.gate_bias[:H]
            r_bias = cell.gate_bias[H:]
            print("[Chrono-LN] z-bias (first 5):", z_bias[:5].detach().cpu().numpy())
            print("[Chrono-LN] r-bias (first 5):", r_bias[:5].detach().cpu().numpy())

    return model
